{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d70339c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.stats import wasserstein_distance\n",
    "import helpers as ph\n",
    "import seaborn as sns\n",
    "import dataframe_image as dfi\n",
    "\n",
    "COLOR_STR = \"#0A3EA4,#4874F9,#84A0F5,#F1F4FB,#FFFFFF\"\n",
    "palette = sns.color_palette(f\"blend:{COLOR_STR}\", 12, as_cmap=True)\n",
    "\n",
    "styles = ph.VIS_STYLES"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "406e94ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "RESULTS_DIR = f'./data/distributions/'\n",
    "CONTEXT = 'default'\n",
    "SAVEFIG = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ca3c5b2",
   "metadata": {},
   "source": [
    "## Load human and LM opinion distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4f60044",
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_df, human_df = [], []\n",
    "for wave in ph.PEW_SURVEY_LIST:\n",
    "    SURVEY_NAME = f'American_Trends_Panel_W{wave}'\n",
    "\n",
    "    cdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_combined.csv'))\n",
    "    cdf['survey'] = f'ATP {wave}'\n",
    "    combined_df.append(cdf)\n",
    "    \n",
    "    hdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_baseline.csv'))\n",
    "    hdf['survey'] = f'ATP {wave}'\n",
    "    human_df.append(hdf)\n",
    "combined_df, human_df = pd.concat(combined_df), pd.concat(human_df)\n",
    "combined_df['Source'] = combined_df.apply(lambda x: 'AI21 Labs' if 'j1-' in x['model_name'].lower() else 'OpenAI',\n",
    "                                          axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed75a60c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print('# Questions:', len(set(combined_df['question'])))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "364d4f88",
   "metadata": {},
   "source": [
    "## Compute average representativeness across dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08f1e19f",
   "metadata": {},
   "outputs": [],
   "source": [
    "KEYS = ['Source', 'model_name', 'attribute', 'group', 'group_order', 'model_order']\n",
    "\n",
    "grouped = combined_df.groupby(KEYS, as_index=False).agg({'WD': np.mean}) \\\n",
    "         .sort_values(by=['model_order', 'group_order'])\n",
    "grouped['Rep'] = 1 - grouped['WD']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "605157ad",
   "metadata": {},
   "source": [
    "### Overall representativeness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "256c6dce",
   "metadata": {},
   "outputs": [],
   "source": [
    "human_baseline = human_df.groupby(['group_x'], as_index=False).agg({'WD': np.mean})\n",
    "human_baseline['Rep'] = 1 - human_baseline['WD']\n",
    "human_baseline = human_baseline.agg({'Rep': (np.mean, min)}).reset_index()\n",
    "human_baseline['model_name'] = human_baseline.apply(lambda x: 'Avg' if x['index'] == 'mean' \\\n",
    "                                                    else 'Worst', axis=1)\n",
    "human_baseline['model_order'] = -1\n",
    "human_baseline['Source'] = \"Humans\"\n",
    "\n",
    "\n",
    "g = pd.concat([human_baseline, grouped[grouped['attribute'] == 'Overall']]).rename(columns={'model_name': '',\n",
    "                                                                                            'Rep': 'R'})\n",
    "\n",
    "table = pd.pivot_table(g, \n",
    "                       columns=['Source', ''], \n",
    "                       values='R', \n",
    "                       sort=False)\n",
    "table_vis = table.style.background_gradient(palette, axis=1).set_table_styles(styles)  \\\n",
    "                        .set_properties(**{\"font-size\":\"0.75rem\"}).format(precision=3)\n",
    "\n",
    "if SAVEFIG: table_vis.hide_index().export_png('./figures/representativeness.png')\n",
    "display(table_vis)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e8f7beb",
   "metadata": {},
   "source": [
    "### Subgroup representativeness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1affeb11",
   "metadata": {},
   "outputs": [],
   "source": [
    "styles[-1]['props'][-1] = (styles[-1]['props'][-1][0], \"105%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8e8da75",
   "metadata": {},
   "outputs": [],
   "source": [
    "for attribute in ph.DEMOGRAPHIC_ATTRIBUTES[1:]:\n",
    "    \n",
    "    print(f'-----{attribute}----')\n",
    "    \n",
    "    g = grouped[grouped['attribute'] == attribute].rename(columns={'model_name': 'Model', 'group': attribute,\n",
    "                                                                  'Source': ''})\n",
    "\n",
    "    table = pd.pivot_table(g, \n",
    "                           index=[attribute], \n",
    "                           columns=['', 'Model'], \n",
    "                           values=\"Rep\", \n",
    "                           sort=False)\n",
    "    table_vis = table.style.background_gradient(palette, axis=(attribute=='Overall')).set_table_styles(styles)  \\\n",
    "                            .set_properties(**{\"font-size\":\"1.3rem\"}).format(precision=3)\n",
    "    if SAVEFIG: table_vis.export_png(f'./figures/representativeness_{attribute}.png')\n",
    "\n",
    "    display(table_vis)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
